// SymmetryExperimenterPlugin.cpp starts

/*******************************************************
* Created: 7/29/00, CGP
* 	Plugin for symmetry experiment
*   Rat is learning
*		if-WEST-then-NORTH
*		if-SOUTH-then-EAST
* Modified: 9/25/00, CGP
*	Starting modifications to use current plugin API.
* Modified: 11/13/00, CGP
*	Added in criterion for training & testing.
*******************************************************/

#include <Message.h>
#include <stdio.h>
#include "SymmetryExperimenterPlugin.h"
#include "EnvironmentConstants.h"
#include "Debug.h"
#include "UserEnvMessage.h"
#include "PluginSettingsNames.h"
#include "DatabaseFieldNames.h"
#include "PerfQueue.h"

// Accuracy the rat must reach: the fraction of correct trials in the performance
// window is compared against this with >= (see TrainCriterionSatisified()). Reaching
// it ends training (switch to testing) or ends testing for the current rat.
#define		PERFORMANCE_CRITERION		0.875

// Number of trials over which performance is assessed. The PerfQueue is constructed
// with this size (presumably it keeps only the most recent trials — confirm in
// PerfQueue.h), and the criterion is not evaluated until this many trials are queued.
#define		PERFORMANCE_WINDOW			64

// Arm numbers for the 4-arm maze, as passed to the EXP_* door/food calls.
// NOTE(review): CENTER is not defined in this file; presumably it comes from
// EnvironmentConstants.h.
#define		NORTH		0
#define		EAST		1
#define		SOUTH		2
#define		WEST		3

// Plugin factory entry point: builds a SymmetryExperimenterPlugin around the supplied
// ports, database/user-environment messages and debug server, and returns it through
// the ExperimenterPlugin base pointer. The object is allocated with new.
// NOTE(review): presumably the plugin host takes ownership and deletes it — confirm.
ExperimenterPlugin *instantiate_exp(PortMessage *expRatPort,
	PortMessage *envExpPort, DatabaseMessage *dbMsg, UserEnvMessage *userEnvMsg,
	DebugServer *bugServer)
{
	SymmetryExperimenterPlugin *plugin = new SymmetryExperimenterPlugin(expRatPort,
		envExpPort, dbMsg, userEnvMsg, bugServer);
	return plugin;
}

// Construct the plugin. The port, database, user-environment and debug handles are
// forwarded unchanged to ExperimenterPlugin, and the performance queue is sized to
// PERFORMANCE_WINDOW trials. The maze itself cannot be configured here because the
// environment is not reachable from the constructor; that happens per trial instead.
SymmetryExperimenterPlugin::SymmetryExperimenterPlugin(PortMessage *expRatPort,
	PortMessage *envExpPort, DatabaseMessage *dbMsg, UserEnvMessage *userEnvMsg,
	DebugServer *bugServer)
	: ExperimenterPlugin(expRatPort, envExpPort, dbMsg, userEnvMsg, bugServer),
	  perfQueue(PERFORMANCE_WINDOW)
{
	// Temporary: settings should eventually be read from the database, once a GUI
	// dialog exists for editing the plugin settings.
	// NOTE(review): ownership of `settings` is not visible in this file. If the base
	// constructor already allocated it, this assignment leaks that object; this class's
	// destructor does not delete it either — confirm against ExperimenterPlugin.
	settings = new BMessage();

	// Supply defaults when the settings message carries no fields. For the freshly
	// created message above this is presumably always the case.
	const bool settingsEmpty = (settings->CountNames(B_ANY_TYPE) == 0);
	if (settingsEmpty) {
		Debug('n', "SymmetryExperimenterPlugin::SymmetryExperimenterPlugin: Null default settings");
		settings->AddInt16(NUMBER_OF_MAZE_ARMS, 4);
	}
}

// Nothing is released explicitly here.
// NOTE(review): the constructor allocates `settings` with new but it is not deleted
// here; confirm whether ExperimenterPlugin's destructor frees it.
SymmetryExperimenterPlugin::~SymmetryExperimenterPlugin() {}

// Write one performance record for the current trial to the experiment database.
//   trialType    - trial type code produced by GetTrialInformation() (1-4)
//   armNumber    - the arm the rat actually chose
//   foundFood    - 1 if the chosen arm contained food (the goal arm), else 0
//   consumedFood - 1 if the rat ate that food, else 0
// The record also carries the current ratNo and trialNumber members.
void SymmetryExperimenterPlugin::AddPerformanceRecord(int trialType,
	int armNumber, int foundFood, int consumedFood)
{
	// The record is only needed for the duration of the EXP_AddDatabaseRecord() call
	// (assumes that call does not keep the pointer after returning), so it lives on
	// the stack: no manual new/delete to get out of step.
	BMessage perfDataRecord;

	// Database only allows int16's right now as data type.
	// NOTE(review): record type 3 / sub-type 6 are codes defined by the database
	// schema, not in this file — confirm against DatabaseFieldNames.h.
	perfDataRecord.AddInt16(DBFLD_RECORD_TYPE, 3);
	perfDataRecord.AddInt16(DBFLD_RECORD_SUB_TYPE, 6);
	perfDataRecord.AddInt16(DBFLD_TRIAL_TYPE, trialType);
	perfDataRecord.AddInt16(DBFLD_RAT_NUMBER, ratNo);
	perfDataRecord.AddInt16(DBFLD_TRIAL_NUMBER, trialNumber);
	perfDataRecord.AddInt16(DBFLD_ARM_NUMBER, armNumber);
	perfDataRecord.AddInt16(DBFLD_CONTAINED_FOOD, foundFood);
	perfDataRecord.AddInt16(DBFLD_CONSUMED_FOOD, consumedFood);
	EXP_AddDatabaseRecord(&perfDataRecord);

	Debug('n', "SymmetryExperimenterPlugin::AddPerformanceRecord: Saved record");
}

// Prepare the maze for one trial. Every door is closed and every arm emptied (in the
// order WEST, NORTH, EAST, SOUTH). Food is then placed in the start and goal arms, the
// start-arm door is opened, and the rat is placed in the center. otherArm (the
// incorrect choice) is left closed and unbaited, so the parameter is not used here.
void SymmetryExperimenterPlugin::SetupForTrial(int startArm, int goalArm, int otherArm)
{
	static const int everyArm[] = { WEST, NORTH, EAST, SOUTH };
	const int armCount = sizeof(everyArm) / sizeof(everyArm[0]);

	for (int d = 0; d < armCount; d++)
		EXP_CloseDoor(everyArm[d]);
	for (int f = 0; f < armCount; f++)
		EXP_RemoveFood(everyArm[f]);

	EXP_PlaceFood(startArm);
	EXP_PlaceFood(goalArm);
	EXP_OpenDoor(startArm);
	EXP_PlaceRatAt(CENTER);
}

// Trial sequences from Ish's RATNET data set. Each string holds 288 characters, one
// per trial, and each character gives the start arm for that trial:
//   train[]: 'S' = start SOUTH, otherwise (W) start WEST
//   test[]:  'E' = start EAST,  otherwise (N) start NORTH
// GetTrialInformation() maps each character to the goal and incorrect arms. The
// backslash-newline at the end of each source line continues the string literal,
// so each array is a single 288-character string.

static char train[] = 
"SWSSWWSWSWWSSWSWSWSWWSWSSWWSSWSWWSWSWSSWSWWSWSWSWSWSSWS\
SSWWWWSSSWWWSSSWWSWWWSSWSSWSWSWWSWSWWSSSWWWSWSSSWSWSWWS\
WSWSSWWSSWWSSWWSWSWWSSSSWWWSWSWWSSSSWWSWSSWWSWSWWSSWSSW\
SWWSWWSWSWSWWSSSWSWWSWWSSSWWWSSWSWSSWWSWSWSSWWWSSWSWSWS\
WSWSSWWSWSSWSWWWSSSWSSSWWWSWWSSWWWSSSWSWWSWSSSWWWSSWWWS\
SWWSWSSSSSWWW";

static char test[] =
"ENEENNEENENNEENNNENNEEENENNNEEEENNNENENEENENNNEENNENEEN\
ENENENNEENENNEENEEENNNENEENENNEEENNEENNENEENENNENENENNN\
ENEENNEENEENNEENENENNENEENENNENEENEENENNNNEENENENEENEEN\
NENNNNEEENENENENNEENEENENENNNENEENNENEEENEENNEENNENNEEN\
ENEENNNEEEENNNENNENEEENNENNNENEEENNNEENENEENENNENENEEEN\
NNEENNEENEENN";

// Highest valid zero-based index into train[] (its string length minus one).
int GetMaxTrainTrial()
{
	const int sequenceLength = static_cast<int>(strlen(train));
	return sequenceLength - 1;
}

// Highest valid zero-based index into test[] (its string length minus one).
int GetMaxTestTrial()
{
	const int sequenceLength = static_cast<int>(strlen(test));
	return sequenceLength - 1;
}
	
// Look up the start arm, goal (correct) arm, incorrect-choice arm (otherArm) and trial
// type for the current trial. The mode member ("Train" or "Test") selects the sequence,
// and the trialNumber member is used directly as the index into train[] or test[]:
//   Train: 'S'       -> type 1, start SOUTH, goal EAST,  other NORTH
//          otherwise -> type 2, start WEST,  goal NORTH, other EAST
//   Test:  'E'       -> type 3, start EAST,  goal SOUTH, other WEST
//          otherwise -> type 4, start NORTH, goal WEST,  other SOUTH
// If trialNumber is past the end of the sequence, or mode is unrecognised, an error is
// logged and all four outputs are set to -1. Callers can detect this with *trialType < 0.
//
// NOTE(review): trialNumber is set to 1 (never 0) in PLUGIN_Setup(), on the
// Train->Test switch in RunNextTrial(), and when a rat finishes. Index 0 of train[] and
// test[] is therefore never used. Fixing that needs a coordinated change here and in
// RunNextTrial()'s end-of-sequence test; confirm the intended indexing first.
void SymmetryExperimenterPlugin::GetTrialInformation(int *startArm, int *goalArm,
	int *otherArm, int *trialType)
{
	// Defined sentinel values for every error path. Previously the outputs were left
	// unwritten, so the caller went on to use indeterminate values.
	*startArm = -1;
	*goalArm = -1;
	*otherArm = -1;
	*trialType = -1;

	if (strcmp(mode, "Train") == 0) {
		if (trialNumber > GetMaxTrainTrial()) {
			Debug('e', "SymmetryExperimenterPlugin::GetTrialInformation: Max train trial exceeded");
			return;
		}
		if (train[trialNumber] == 'S') {
			*trialType = 1;
			*startArm = SOUTH;
			*goalArm = EAST;
			*otherArm = NORTH;
		} else {
			*trialType = 2;
			*startArm = WEST;
			*goalArm = NORTH;
			*otherArm = EAST;
		}
	} else if (strcmp(mode, "Test") == 0) {
		if (trialNumber > GetMaxTestTrial()) {
			Debug('e', "SymmetryExperimenterPlugin::GetTrialInformation: Max test trial exceeded");
			return;
		}
		if (test[trialNumber] == 'E') {
			*trialType = 3;
			*startArm = EAST;
			*goalArm = SOUTH;
			*otherArm = WEST;
		} else {
			*trialType = 4;
			*startArm = NORTH;
			*goalArm = WEST;
			*otherArm = SOUTH;
		}
	} else {
		Debug('e', "SymmetryExperimenterPlugin::GetTrialInformation: Bad flag: ", mode);
	}
}

// Run a single trial for the rat. Assumes the maze has already been baited, the
// start-arm door opened, and the rat placed in the center (as SetupForTrial() does).
//
// Sequence:
//   1. wait for the rat to eat the food down the start arm;
//   2. open the goal and incorrect (other) arms;
//   3. when the rat re-enters the center, close the start arm;
//   4. when it enters an arm, close the arm it did not choose (one choice per trial);
//   5. wait until it either eats food or re-enters the center, then take it off the
//      maze and write one performance record.
// Returns true if the rat ate food in its chosen arm, false otherwise.
bool SymmetryExperimenterPlugin::RunTrialProtocol(int startArm, int goalArm, int otherArm, int trialType)
{
	// 1. Start-arm food.
	EXP_WaitForEvent(EXP_EVENT_FOOD_CONSUMED);
	Debug('n', "SymmetryExperimenterPlugin::RunTrialProtocol: Rat ate food");

	// 2-3. Offer the choice and shut the way back.
	EXP_OpenDoor(goalArm);
	EXP_OpenDoor(otherArm);
	EXP_WaitForEvent(EXP_EVENT_CENTER_ENTRY);
	Debug('n', "SymmetryExperimenterPlugin::RunTrialProtocol: Rat entered center of maze");
	EXP_CloseDoor(startArm);

	// 4. The rat commits to one arm; close the alternative.
	const int chosenArm = EXP_WaitForEvent(EXP_EVENT_ARM_ENTRY);
	Debug('n', "SymmetryExperimenterPlugin::RunTrialProtocol: Rat made arm choice");
	const bool choseGoal = (chosenArm == goalArm);
	EXP_CloseDoor(choseGoal ? otherArm : goalArm);

	// 5. Either the rat eats (correct arm) or it heads back to the center.
	const int outcome = EXP_WaitForEvent(EXP_EVENT_CENTER_ENTRY | EXP_EVENT_FOOD_CONSUMED);
	Debug('n', "SymmetryExperimenterPlugin::RunTrialProtocol: Rat entered center or ate food");
	EXP_TakeRatOffMaze();

	// Any outcome other than a center entry means the food was eaten.
	const bool ateFood = (outcome != CENTER);
	if (ateFood) {
		AddPerformanceRecord(trialType, chosenArm, 1, 1);
	} else {
		// Returned to the center: it either left the goal food uneaten or chose the
		// incorrect (unbaited) arm.
		AddPerformanceRecord(trialType, chosenArm, choseGoal ? 1 : 0, 0);
	}

	return ateFood;
}

// True once the performance queue holds at least PERFORMANCE_WINDOW trials and the
// fraction of those trials on which the rat found food is >= PERFORMANCE_CRITERION.
// The queue's Performance() value is the count of trials where food was found, i.e.
// the number of correct trials.
bool SymmetryExperimenterPlugin::TrainCriterionSatisified()
{
	const int trialsQueued = perfQueue.CurrLength();

	// The criterion is only judged over a full window of trials.
	if (trialsQueued >= PERFORMANCE_WINDOW) {
		const float hitRate = static_cast<float>(perfQueue.Performance())
			/ static_cast<float>(trialsQueued);
		return hitRate >= PERFORMANCE_CRITERION;
	}

	return false;
}

// Testing ends on the same criterion as training: at least PERFORMANCE_CRITERION
// correct over a full PERFORMANCE_WINDOW of queued trials.
bool SymmetryExperimenterPlugin::TestCriterionSatisified()
{
	const bool metCriterion = TrainCriterionSatisified();
	return metCriterion;
}

// Run next training or testing trial. Returns true iff we've finished all training
// and testing for the current rat.
bool SymmetryExperimenterPlugin::RunNextTrial()
{
	bool doneRat = false;
	int startArm, goalArm, otherArm;
	int trialType;
	char buffer[50];
	
	GetTrialInformation(&startArm, &goalArm, &otherArm, &trialType);
	SetupForTrial(startArm, goalArm, otherArm);

	sprintf(buffer, "%d", startArm);
	Debug('n', "SymmetryExperimenterPlugin::RunNextTrial: startArm= ", buffer);
	sprintf(buffer, "%d", goalArm);
	Debug('n', "SymmetryExperimenterPlugin::RunNextTrial: goalArm= ", buffer);
	sprintf(buffer, "%d", otherArm);
	Debug('n', "SymmetryExperimenterPlugin::RunNextTrial: otherArm= ", buffer);
	
	// send rat trial type, just for evaluation of model at rat	
	EXP_PutRatOnMaze(trialType);
	
	bool foundFood = RunTrialProtocol(startArm, goalArm, otherArm, trialType);
	perfQueue.Enqueue(foundFood);
	
	// RunTrialProtocol already took the rat off the maze
	
	EXP_PlaceRatAt(CENTER); // reposition in center, just for aesthetics

	trialNumber++;
	
	if (strcmp(mode, "Train") == 0) {
		// Switch from training to testing?
		if ((trialNumber > GetMaxTrainTrial()) || TrainCriterionSatisified()) {
			mode = "Test";
			trialNumber = 1;
			perfQueue.Reset(); // clear training data from performance queue
		}
	} else if (strcmp(mode, "Test") == 0) {
		// Finished testing?
		if ((trialNumber > GetMaxTestTrial()) || TestCriterionSatisified()) {
			doneRat = true;
		}
	} else {
		Debug('e', "SymmetryExperimenterPlugin::RunNextTrial: Error: Bad mode= ", mode);
		doneRat = true;
	}
	
	return doneRat;
}

void SymmetryExperimenterPlugin::PLUGIN_Setup()
{
	Debug('n', "SymmetryExperimenterPlugin::PLUGIN_Setup: Start");
	
	ratNo = 1;
	trialNumber = 1;
	EXP_SetCurrentRatNumber(ratNo);
	mode = "Train";
	
	Debug('n', "SymmetryExperimenterPlugin::PLUGIN_Setup: End");
}

void SymmetryExperimenterPlugin::PLUGIN_RunTrial()
{
	Debug('n', "SymmetryExperimenterPlugin::PLUGIN_RunTrial: Start");
	
	if (RunNextTrial()) {
		// Done current rat
		trialNumber = 1;
		ratNo++;
		// TO DO: Check to see if we've done all the rats (need database call)
		EXP_SetCurrentRatNumber(ratNo);
	}
	
	Debug('n', "SymmetryExperimenterPlugin::PLUGIN_RunTrial: End");
}

// Complete trials for current rat
void SymmetryExperimenterPlugin::PLUGIN_RunCurrentRat()
{
	Debug('n', "SymmetryExperimenterPlugin::PLUGIN_RunCurrentRat: Start");

	for (;;) {
		if (RunNextTrial()) {
			// Rat is done all trials
			break;
		}
		
		EXP_Wait(0.4); // separate the trials by a pause
	}
	
	// Done current rat
	ratNo++;
	trialNumber = 1;
	
	// TO DO: Check to see if we've done all the rats (need database call)
	EXP_SetCurrentRatNumber(ratNo);
	
	Debug('n', "SymmetryExperimenterPlugin::PLUGIN_RunCurrentRat: End");
}

/*
void SymmetryExperimenterPlugin::PLUGIN_RunStep()
{
	Debug('e', "SymmetryExperimenterPlugin::PLUGIN_RunStep: Not Implemented");
}
*/

// Event callbacks that this experimenter never expects to receive: rat events are
// consumed explicitly through EXP_WaitForEvent() in RunTrialProtocol(). Any call is
// therefore logged as an error.
void SymmetryExperimenterPlugin::PLUGIN_RatConsumedFood()
{
	Debug('e', "SymmetryExperimenterPlugin::PLUGIN_RatConsumedFood: ERROR: Should not receive");
}

void SymmetryExperimenterPlugin::PLUGIN_RatMoved(BMessage *msg)
{
	Debug('e', "SymmetryExperimenterPlugin::PLUGIN_RatMoved: ERROR: Should not receive");
} 

// SymmetryExperimenterPlugin.cpp ends

